import torch

# 创建图片数据迭代器
def colle(batch):
    """Collate a batch of ``(image, label)`` samples into two tensors.

    Example with batch_size=2: ``batch`` looks like
    ``((picture1_tensor, picture1_label), (picture2_tensor, picture2_label))``.
    The pairs are transposed so that all image tensors are grouped together
    and all label tensors are grouped together. Each group is then joined
    along dim 0 with ``torch.cat``.

    NOTE(review): ``torch.cat`` does not add a new batch dimension. Images
    shaped ``(1, C, H, W)`` give an ``(N, C, H, W)`` batch. Images shaped
    ``(C, H, W)`` would give ``(N*C, H, W)``, in which case ``torch.stack``
    is what is wanted. Confirm against the dataset.

    Zero-dimensional tensors (e.g. scalar class indices) cannot be passed to
    ``torch.cat``. They are promoted to shape ``(1,)`` first, so a batch of
    scalar labels becomes a 1-D tensor of length N.

    Args:
        batch: iterable of ``(image_tensor, label_tensor)`` pairs, e.g. the
            list a ``DataLoader`` hands to its ``collate_fn``.

    Returns:
        Tuple ``(imgs, targets)`` of concatenated tensors.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    samples = list(batch)
    if not samples:
        # Without this check the unpacking below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError("colle() received an empty batch")
    # Transpose [(img1, t1), (img2, t2), ...] into (img1, img2, ...), (t1, t2, ...).
    imgs, targets = zip(*samples)
    return _cat_dim0(imgs), _cat_dim0(targets)


def _cat_dim0(tensors):
    """Concatenate tensors along dim 0, promoting 0-dim tensors to shape (1,)."""
    return torch.cat([t.reshape(1) if t.dim() == 0 else t for t in tensors], dim=0)


"""

# 修正目标标签格式，使其可用于 CrossEntropyLoss (fix the targets so nn.CrossEntropyLoss can be used)
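For nn.CrossEntropyLoss with integer targets, each target has to be a single
class index in the range [0, #classes - 1] instead of a one-hot encoded
target vector.
Your target is [1, 0],
thus PyTorch thinks you want to have multiple labels per input,
which is not supported for class-index targets.
(Since PyTorch 1.10, CrossEntropyLoss also accepts float class-probability
targets, but these must then have the same shape as the input, i.e.
(N, #classes), and a floating-point dtype.)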

Replace your one-hot-encoded targets with class indices
(e.g. `targets = targets.argmax(dim=1)` for an (N, #classes) tensor):

[1, 0] --> 0

[0, 1] --> 1

"""